#ehdr_struct = 16sHHIQQQIHHHHHH  (ELF64 header: 16-byte e_ident + 48 bytes of fields = 64 bytes)
import collections
import os
import struct
import sys

import elf

# The 32-bit word 0x1b09644b serialized big-endian (bytes 1b 09 64 4b).
# NOTE(review): the name suggests an encoded shift instruction on register 27;
# it is not referenced in the code shown here -- confirm its use.
SLL27 = bytearray((0x1b09644b).to_bytes(4, "big"))

# Values meant for gen_mem's `opc` field (placed at bit 26 of the word).
# Presumably the LDI / LDIH opcodes of the target ISA -- confirm.
LDI, LDIH = 0x3e, 0x3f
def gen_mem(opc, ra, rb, disp):
    """Encode a memory-format instruction word as 4 little-endian bytes.

    Field placement: ``opc`` at bit 26, ``ra`` at bit 21, ``rb`` at bit 16,
    and the low 16 bits of ``disp`` in bits 15..0 (a negative ``disp`` is
    therefore stored in two's complement). ``opc``/``ra``/``rb`` are not
    masked; a word that does not fit in 32 bits makes struct.pack raise.
    """
    layout = ((opc, 26), (ra, 21), (rb, 16), (disp & 0xffff, 0))
    word = 0
    for field, shift in layout:
        word |= field << shift
    return bytearray(struct.pack('<I', word))
def set_disp(orig, disp):
    """Return a copy of a 4-byte little-endian word with its low 16 bits replaced.

    ``orig`` must be exactly 4 bytes (struct.unpack enforces this). The upper
    16 bits are kept and the lower 16 bits become ``disp & 0xffff``.
    """
    (word,) = struct.unpack('<I', orig)
    patched = (word & 0xffff0000) | (disp & 0xffff)
    return bytearray(struct.pack('<I', patched))
def depart_long(x):
    """Split a 64-bit value into four 16-bit displacement parts with carries.

    Returns ``[p0, p1, p2, p3]`` with ``p0`` the least significant halfword.
    Whenever one of ``p0``..``p2`` would be >= 0x8000, it is made negative
    (minus 0x10000) and 1 is carried into the next part. As a result
    ``p0``..``p2`` lie in [-0x8000, 0x7fff] and ``p3`` is masked to
    [0, 0xffff]. Together they satisfy
    ``p0 + (p1 << 16) + (p2 << 32) + (p3 << 48) == x (mod 2**64)``.

    Raises struct.error if ``x`` is outside [0, 2**64).
    """
    # Explicit little-endian formats. Native 'Q'/'HHHH' would make parts[0]
    # the most significant halfword on a big-endian host.
    parts = list(struct.unpack('<4H', struct.pack('<Q', x)))
    for i in range(3):
        if parts[i] > 0x7fff:
            parts[i] -= 0x10000
            parts[i + 1] += 1   # may reach 0x10000; handled on the next pass
    parts[3] &= 0xffff          # a carry out of the top halfword wraps to 0
    return parts
# Sink for the per-symbol resolution messages printed by post_process_refs.
# Everything written here is discarded. os.devnull is the portable null-device
# path ("/dev/null" on POSIX, "nul" on Windows).
verbose_log = open(os.devnull, "w")
def post_process_refs(path_in, path_out, verbose=True):
    """Resolve and patch symbolic references in an ELF file, then write it out.

    The ELF64 image at *path_in* is parsed and a local definition table is
    built from symbols named ``defsym.<a>.<b>`` (keyed ``"<a>.<b>"``).

    Every symbol named ``loadsym.<name>.<b>`` or ``usesym.<name>.<b>`` is
    then resolved. First ``<name>`` is prefixed with ``slave_`` if it lacks
    that prefix. The symbol is looked up in the local table under
    ``"<name>.<b>"``, and failing that in the ELF's global symbols under
    ``<name>``.

    For ``loadsym.`` symbols, the resolved address is split with depart_long().
    The parts are written into the 16-bit displacement fields of the
    instructions at byte offsets +0, +4, +12 and +16 from the symbol's
    position in ``.text1``. ``usesym.`` symbols are only checked for
    resolvability; they are not patched.

    Both paths are echoed to stdout.

    Args:
        path_in: Path of the input ELF file.
        path_out: Path the patched ELF image is written to.
        verbose: Accepted for compatibility but currently unused. Resolution
            messages always go to the module-level ``verbose_log`` sink.

    Returns:
        True on success. False if a ``loadsym.`` symbol cannot be resolved,
        or if any exception occurs (its traceback is printed to stderr).
    """
    # (byte offset from the loadsym label, depart_long() index) for each
    # instruction whose displacement is rewritten; the word at +8 is untouched.
    patch_layout = ((0, 2), (4, 3), (12, 0), (16, 1))
    try:
        print(path_in)
        with open(path_in, "rb") as f_in:
            elf_in = elf.ELF64(f_in.read())
        text1 = elf_in.shdrs['.text1']
        text1_offset = text1.sh_offset
        text1_addr = text1.sh_addr

        local_symtab = {}
        for sym in elf_in.syms:
            if sym.st_name.startswith("defsym."):
                symsp = sym.st_name.split(".")
                local_symtab[symsp[1] + "." + symsp[2]] = sym.st_value

        for sym in elf_in.syms:
            name = sym.st_name
            is_load = name.startswith("loadsym.")
            if not (is_load or name.startswith("usesym.")):
                continue
            symsp = name.split(".")
            if not symsp[1].startswith("slave_"):
                symsp[1] = "slave_" + symsp[1]
            target = symsp[1]
            local_key = target + "." + symsp[2]

            if local_key in local_symtab:
                dest_addr = local_symtab[local_key]
                print("%s: is resolved locally as %x" % (name, dest_addr), file=verbose_log)
            elif target in elf_in.global_syms:
                dest_addr = elf_in.global_syms[target].st_value
                print("%s: is resolved globally as %x" % (name, dest_addr), file=verbose_log)
            else:
                print("unresolved symbol: %s" % target, file=sys.stderr)
                if is_load:
                    return False
                continue

            if not is_load:
                # usesym. references are only checked; nothing is patched.
                continue

            # Map the symbol's virtual address inside .text1 to a file offset.
            inst_off = sym.st_value - text1_addr + text1_offset
            parts = depart_long(dest_addr)
            for rel, idx in patch_layout:
                start = inst_off + rel
                elf_in.bin[start:start + 4] = set_disp(elf_in.bin[start:start + 4], parts[idx])

        print(path_out)
        with open(path_out, "wb") as f_out:
            f_out.write(elf_in.bin)
        return True
    except Exception:
        # Best-effort tool boundary: report and signal failure. Unlike a bare
        # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Command line: <script> <input-elf> <output-elf>.
    # Exit status: 2 on missing arguments, 1 if patching fails.
    if len(sys.argv) < 3:
        print("usage: %s <input-elf> <output-elf>" % sys.argv[0], file=sys.stderr)
        sys.exit(2)
    if not post_process_refs(sys.argv[1], sys.argv[2]):
        sys.exit(1)